"""passlib.handlers.scram - hash for SCRAM credential storage"""
#=============================================================================
# imports
#=============================================================================
# core
import logging; log = logging.getLogger(__name__)
# site
# pkg
from lib.passlib.utils import consteq, saslprep, to_native_str, splitcomma
from lib.passlib.utils.binary import ab64_decode, ab64_encode
from lib.passlib.utils.compat import bascii_to_str, iteritems, u, native_string_types
from lib.passlib.crypto.digest import pbkdf2_hmac, norm_hash_name
import lib.passlib.utils.handlers as uh
# local
# public names exported by ``from <this module> import *``
__all__ = ["scram"]

#=============================================================================
# scram credentials hash
#=============================================================================
class scram(uh.HasRounds, uh.HasRawSalt, uh.HasRawChecksum, uh.GenericHandler):
    """This class provides a format for storing SCRAM passwords, and follows
    the :ref:`password-hash-api`.

    It supports a variable-length salt, and a variable number of rounds.

    The :meth:`~passlib.ifc.PasswordHash.using` method accepts the following optional keywords:

    :type salt: bytes
    :param salt:
        Optional salt bytes.
        If specified, the length must be between 0-1024 bytes.
        If not specified, a 12 byte salt will be autogenerated
        (this is recommended).

    :type salt_size: int
    :param salt_size:
        Optional number of bytes to use when autogenerating new salts.
        Defaults to 12 bytes, but can be any value between 0 and 1024.

    :type rounds: int
    :param rounds:
        Optional number of rounds to use.
        Defaults to 100000, but must be within ``range(1,1<<32)``.

    :type algs: list of strings
    :param algs:
        Specify list of digest algorithms to use.

        By default each scram hash will contain digests for SHA-1,
        SHA-256, and SHA-512. This can be overridden by specifying either a
        list such as ``["sha-1", "sha-256"]``, or a comma-separated string
        such as ``"sha-1, sha-256"``. Names are case insensitive, and may
        use :mod:`!hashlib` or `IANA <http://www.iana.org/assignments/hash-function-text-names>`_
        hash names. Names that normalize to the same IANA name are only
        stored once. ``"sha-1"`` must always be included, since the SCRAM
        spec (RFC 5802) requires it.

    :type relaxed: bool
    :param relaxed:
        By default, providing an invalid value for one of the other
        keywords will result in a :exc:`ValueError`. If ``relaxed=True``,
        and the error can be corrected, a :exc:`~passlib.exc.PasslibHashWarning`
        will be issued instead. Correctable errors include ``rounds``
        that are too small or too large, and ``salt`` strings that are too long.

        .. versionadded:: 1.6

    In addition to the standard :ref:`password-hash-api` methods,
    this class also provides the following methods for manipulating Passlib
    scram hashes in ways useful for plugging into a SCRAM protocol stack:

    .. automethod:: extract_digest_info
    .. automethod:: extract_digest_algs
    .. automethod:: derive_digest
    """
    #===================================================================
    # class attrs
    #===================================================================

    # NOTE: unlike most GenericHandler classes, the 'checksum' attr of
    # ScramHandler is actually a map from digest_name -> digest, so
    # many of the standard methods have been overridden.

    # NOTE: max_salt_size and max_rounds are arbitrarily chosen to provide
    # a sanity check; the underlying pbkdf2 specifies no bounds for either.

    #--GenericHandler--
    name = "scram"
    setting_kwds = ("salt", "salt_size", "rounds", "algs")
    ident = u("$scram$")

    #--HasSalt--
    default_salt_size = 12
    max_salt_size = 1024

    #--HasRounds--
    default_rounds = 100000
    min_rounds = 1
    max_rounds = 2**32-1
    rounds_cost = "linear"

    #--custom--

    # default algorithms when creating new hashes.
    default_algs = ["sha-1", "sha-256", "sha-512"]

    # list of algs verify prefers to use, in order.
    _verify_algs = ["sha-256", "sha-512", "sha-224", "sha-384", "sha-1"]

    #===================================================================
    # instance attrs
    #===================================================================

    # 'checksum' is different from most GenericHandler subclasses,
    # in that it contains a dict mapping from alg -> digest,
    # or None if no checksum present.

    # list of algorithms to create/compare digests for.
    algs = None

    #===================================================================
    # scram frontend helpers
    #===================================================================
    @classmethod
    def extract_digest_info(cls, hash, alg):
        """return (salt, rounds, digest) for specific hash algorithm.

        :type hash: str
        :arg hash:
            :class:`!scram` hash stored for desired user

        :type alg: str
        :arg alg:
            Name of digest algorithm (e.g. ``"sha-1"``) requested by client.

            This value is run through :func:`~passlib.crypto.digest.norm_hash_name`,
            so it is case-insensitive, and can be the raw SCRAM
            mechanism name (e.g. ``"SCRAM-SHA-1"``), the IANA name,
            or the hashlib name.

        :raises KeyError:
            If the hash does not contain an entry for the requested digest
            algorithm.

        :raises ValueError:
            If the hash is a config string containing no digests at all
            (or cannot be parsed).

        :returns:
            A tuple containing ``(salt, rounds, digest)``,
            where *digest* matches the raw bytes returned by
            SCRAM's :func:`Hi` function for the stored password,
            the provided *salt*, and the iteration count (*rounds*).
            *salt* and *digest* are both raw (unencoded) bytes.
        """
        # XXX: this could be sped up by writing custom parsing routine
        # that just picks out relevant digest, and doesn't bother
        # with full structure validation each time it's called.
        alg = norm_hash_name(alg, 'iana')
        self = cls.from_string(hash)
        chkmap = self.checksum
        if not chkmap:
            raise ValueError("scram hash contains no digests")
        return self.salt, self.rounds, chkmap[alg]

    @classmethod
    def extract_digest_algs(cls, hash, format="iana"):
        """Return names of all algorithms stored in a given hash.

        :type hash: str
        :arg hash:
            The :class:`!scram` hash to parse

        :type format: str
        :param format:
            This changes the naming convention used by the
            returned algorithm names. By default the names
            are IANA-compatible; possible values are ``"iana"`` or ``"hashlib"``.

        :returns:
            Returns a list of digest algorithms; e.g. ``["sha-1"]``
        """
        # XXX: this could be sped up by writing custom parsing routine
        # that just picks out relevant names, and doesn't bother
        # with full structure validation each time it's called.
        algs = cls.from_string(hash).algs
        if format == "iana":
            return algs
        else:
            return [norm_hash_name(alg, format) for alg in algs]

    @classmethod
    def derive_digest(cls, password, salt, rounds, alg):
        """helper to create SaltedPassword digest for SCRAM.

        This performs the step in the SCRAM protocol described as::

            SaltedPassword  := Hi(Normalize(password), salt, i)

        :type password: unicode or utf-8 bytes
        :arg password: password to run through digest

        :type salt: bytes
        :arg salt: raw salt data

        :type rounds: int
        :arg rounds: number of iterations.

        :type alg: str
        :arg alg: name of digest to use (e.g. ``"sha-1"``).

        :returns:
            raw bytes of ``SaltedPassword``
        """
        if isinstance(password, bytes):
            password = password.decode("utf-8")
        # NOTE: pbkdf2_hmac() will encode secret & salt using utf-8,
        #       and handle normalizing alg name.
        return pbkdf2_hmac(alg, saslprep(password), salt, rounds)

    #===================================================================
    # serialization
    #===================================================================

    @classmethod
    def from_string(cls, hash):
        """Parse a ``$scram$<rounds>$<salt>$<digests>`` string.

        The final field is either a comma-separated list of
        ``alg=digest`` pairs, or a bare comma-separated list of algorithm
        names (a config string without digests).

        :raises ~passlib.exc.InvalidHashError:
            if *hash* does not start with ``$scram$``.
        :raises ~passlib.exc.MalformedHashError:
            if any field cannot be decoded.
        """
        hash = to_native_str(hash, "ascii", "hash")
        if not hash.startswith("$scram$"):
            raise uh.exc.InvalidHashError(cls)
        parts = hash[7:].split("$")
        if len(parts) != 3:
            raise uh.exc.MalformedHashError(cls)
        rounds_str, salt_str, chk_str = parts

        # decode rounds; the round-trip comparison forbids zero padding,
        # a leading '+', surrounding whitespace, etc.
        try:
            rounds = int(rounds_str)
        except ValueError:
            raise uh.exc.MalformedHashError(cls, "invalid rounds")
        if rounds_str != str(rounds):
            raise uh.exc.MalformedHashError(cls, "invalid rounds")

        # decode salt -- bad base64 may surface as TypeError or ValueError.
        try:
            salt = ab64_decode(salt_str.encode("ascii"))
        except (TypeError, ValueError):
            raise uh.exc.MalformedHashError(cls, "invalid salt")

        # decode algs/digest list
        if not chk_str:
            # scram hashes MUST have something here.
            raise uh.exc.MalformedHashError(cls)
        elif "=" in chk_str:
            # comma-separated list of 'alg=digest' pairs
            algs = None
            chkmap = {}
            for pair in chk_str.split(","):
                alg, sep, digest = pair.partition("=")
                # every pair needs exactly one '=' and a non-empty alg name
                # (ab64 encoding never produces '=' in the digest itself).
                if not alg or not sep or "=" in digest:
                    raise uh.exc.MalformedHashError(cls, "malformed digest list")
                try:
                    chkmap[alg] = ab64_decode(digest.encode("ascii"))
                except (TypeError, ValueError):
                    raise uh.exc.MalformedHashError(cls, "invalid digest")
        else:
            # comma-separated list of alg names, no digests
            algs = chk_str
            chkmap = None

        # return new object
        return cls(
            rounds=rounds,
            salt=salt,
            checksum=chkmap,
            algs=algs,
        )

    def to_string(self):
        """Render this instance as a ``$scram$`` string.

        When digests are present, the final field is a comma-separated list
        of ``alg=digest`` pairs. When ``checksum`` is ``None``, only the
        algorithm names are emitted, which is the config-string form that
        :meth:`from_string` also accepts.
        """
        salt = bascii_to_str(ab64_encode(self.salt))
        chkmap = self.checksum
        if chkmap is None:
            # config string: no digests to encode, just the alg names.
            chk_str = ",".join(self.algs)
        else:
            chk_str = ",".join(
                "%s=%s" % (alg, bascii_to_str(ab64_encode(chkmap[alg])))
                for alg in self.algs
            )
        return '$scram$%d$%s$%s' % (self.rounds, salt, chk_str)

    #===================================================================
    # variant constructor
    #===================================================================
    @classmethod
    def using(cls, default_algs=None, algs=None, **kwds):
        """Create a subclass with customized default settings.

        ``algs`` is accepted as an alias for ``default_algs``; all other
        keywords are forwarded to the base implementation.

        :raises TypeError:
            if both ``algs`` and ``default_algs`` are specified.
        """
        # parse aliases -- validated explicitly rather than with ``assert``,
        # which is stripped when python runs with ``-O``.
        if algs is not None:
            if default_algs is not None:
                raise TypeError("'algs' and 'default_algs' are aliases; "
                                "specify only one")
            default_algs = algs

        # create subclass
        subcls = super(scram, cls).using(**kwds)

        # fill in algs (normalized once here, so __init__ can trust them)
        if default_algs is not None:
            subcls.default_algs = cls._norm_algs(default_algs)
        return subcls

    #===================================================================
    # init
    #===================================================================
    def __init__(self, algs=None, **kwds):
        """Initialize the handler instance.

        The algorithm list is taken from, in priority order: the explicit
        ``algs`` keyword, the keys of the ``checksum`` dict, or (when
        ``use_defaults`` is set) ``default_algs``.

        :raises RuntimeError: if both ``algs`` and ``checksum`` are given.
        :raises TypeError: if no algorithm list can be determined.
        """
        super(scram, self).__init__(**kwds)

        # init algs
        digest_map = self.checksum
        if algs is not None:
            if digest_map is not None:
                raise RuntimeError("checksum & algs kwds are mutually exclusive")
            algs = self._norm_algs(algs)
        elif digest_map is not None:
            # derive algs list from digest map (if present).
            algs = self._norm_algs(digest_map.keys())
        elif self.use_defaults:
            algs = list(self.default_algs)
            assert self._norm_algs(algs) == algs, "invalid default algs: %r" % (algs,)
        else:
            raise TypeError("no algs list specified")
        self.algs = algs

    def _norm_checksum(self, checksum, relaxed=False):
        """Validate the ``alg -> digest`` checksum dict.

        Each key must already be a normalized IANA name of at most 9
        characters, each value must be raw bytes, and ``sha-1`` must be
        present. The dict is returned unchanged; *relaxed* is unused here.
        """
        if not isinstance(checksum, dict):
            raise uh.exc.ExpectedTypeError(checksum, "dict", "checksum")
        for alg, digest in iteritems(checksum):
            if alg != norm_hash_name(alg, 'iana'):
                raise ValueError("malformed algorithm name in scram hash: %r" %
                                 (alg,))
            if len(alg) > 9:
                raise ValueError("SCRAM limits algorithm names to "
                                 "9 characters: %r" % (alg,))
            if not isinstance(digest, bytes):
                raise uh.exc.ExpectedTypeError(digest, "raw bytes", "digests")
            # TODO: verify digest size (if digest is known)
        if 'sha-1' not in checksum:
            # NOTE: required because of SCRAM spec.
            raise ValueError("sha-1 must be in algorithm list of scram hash")
        return checksum

    @classmethod
    def _norm_algs(cls, algs):
        """normalize algs parameter.

        Accepts a comma-separated string or an iterable of names; returns a
        sorted list of unique IANA names. Aliases which normalize to the same
        name (e.g. ``"SHA1"`` and ``"sha-1"``) are collapsed, so that
        generated hashes never contain duplicate digest entries.
        """
        if isinstance(algs, native_string_types):
            algs = splitcomma(algs)
        algs = sorted(set(norm_hash_name(alg, 'iana') for alg in algs))
        if any(len(alg)>9 for alg in algs):
            raise ValueError("SCRAM limits alg names to max of 9 characters")
        if 'sha-1' not in algs:
            # NOTE: required because of SCRAM spec (rfc 5802)
            raise ValueError("sha-1 must be in algorithm list of scram hash")
        return algs

    #===================================================================
    # migration
    #===================================================================
    def _calc_needs_update(self, **kwds):
        """Flag hashes missing any of ``default_algs``, then defer to base."""
        # marks hashes as deprecated if they don't include at least all default_algs.
        # XXX: should we deprecate if they aren't exactly the same,
        #      to permit removing legacy hashes?
        if not set(self.algs).issuperset(self.default_algs):
            return True

        # hand off to base implementation
        return super(scram, self)._calc_needs_update(**kwds)

    #===================================================================
    # digest methods
    #===================================================================
    def _calc_checksum(self, secret, alg=None):
        """Derive digest(s) for *secret* using this instance's salt & rounds.

        :returns:
            raw digest bytes for *alg* if given, otherwise a dict mapping
            every name in ``self.algs`` to its digest.
        """
        rounds = self.rounds
        salt = self.salt
        derive = self.derive_digest
        if alg:
            # if requested, generate digest for specific alg
            return derive(secret, salt, rounds, alg)
        else:
            # by default, return dict containing digests for all algs
            return dict(
                (name, derive(secret, salt, rounds, name))
                for name in self.algs
            )

    @classmethod
    def verify(cls, secret, hash, full=False):
        """Verify *secret* against a scram *hash*.

        By default only the single most-preferred digest present (per
        ``_verify_algs``) is checked. With ``full=True`` every digest is
        checked, and a :exc:`ValueError` is raised if they disagree or if a
        stored digest has the wrong length.
        """
        uh.validate_secret(secret)
        self = cls.from_string(hash)
        chkmap = self.checksum
        if not chkmap:
            raise ValueError("expected %s hash, got %s config string instead" %
                             (cls.name, cls.name))

        # NOTE: to make the verify method efficient, we just calculate hash
        # of shortest digest by default. apps can pass in "full=True" to
        # check entire hash for consistency.
        if full:
            correct = failed = False
            for alg, digest in iteritems(chkmap):
                other = self._calc_checksum(secret, alg)
                # NOTE: could do this length check in norm_algs(),
                # but don't need to be that strict, and want to be able
                # to parse hashes containing algs not supported by platform.
                # it's fine if we fail here though.
                if len(digest) != len(other):
                    raise ValueError("mis-sized %s digest in scram hash: %r != %r"
                                     % (alg, len(digest), len(other)))
                if consteq(other, digest):
                    correct = True
                else:
                    failed = True
            if correct and failed:
                raise ValueError("scram hash verified inconsistently, "
                                 "may be corrupted")
            else:
                return correct
        else:
            # XXX: should this just always use sha1 hash? would be faster.
            # otherwise only verify against one hash, pick one w/ best security.
            for alg in self._verify_algs:
                if alg in chkmap:
                    other = self._calc_checksum(secret, alg)
                    return consteq(other, chkmap[alg])
            # there should always be sha-1 at the very least,
            # or something went wrong inside _norm_algs()
            raise AssertionError("sha-1 digest not found!")

    #===================================================================
    #
    #===================================================================

#=============================================================================
# code used for testing scram against protocol examples during development.
#=============================================================================
##def _test_reference_scram():
##    "quick hack testing scram reference vectors"
##    # NOTE: "n,," is GS2 header - see https://tools.ietf.org/html/rfc5801
##    from passlib.utils.compat import print_
##
##    engine = _scram_engine(
##        alg="sha-1",
##        salt='QSXCR+Q6sek8bf92'.decode("base64"),
##        rounds=4096,
##        password=u("pencil"),
##    )
##    print_(engine.digest.encode("base64").rstrip())
##
##    msg = engine.format_auth_msg(
##        username="user",
##        client_nonce = "fyko+d2lbbFgONRv9qkxdawL",
##        server_nonce = "3rfcNHYJY1ZVvWVs7j",
##        header='c=biws',
##    )
##
##    cp = engine.get_encoded_client_proof(msg)
##    assert cp == "v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=", cp
##
##    ss = engine.get_encoded_server_sig(msg)
##    assert ss == "rmF9pqV8S7suAoZWja4dJRkFsKQ=", ss
##
##class _scram_engine(object):
##    """helper class for verifying scram hash behavior
##    against SCRAM protocol examples. not officially part of Passlib.
##
##    takes in alg, salt, rounds, and a digest or password.
##
##    can calculate the various keys & messages of the scram protocol.
##
##    """
##    #=========================================================
##    # init
##    #=========================================================
##
##    @classmethod
##    def from_string(cls, hash, alg):
##        "create record from scram hash, for given alg"
##        return cls(alg, *scram.extract_digest_info(hash, alg))
##
##    def __init__(self, alg, salt, rounds, digest=None, password=None):
##        self.alg = norm_hash_name(alg)
##        self.salt = salt
##        self.rounds = rounds
##        self.password = password
##        if password:
##            data = scram.derive_digest(password, salt, rounds, alg)
##            if digest and data != digest:
##                raise ValueError("password doesn't match digest")
##            else:
##                digest = data
##        elif not digest:
##            raise TypeError("must provide password or digest")
##        self.digest = digest
##
##    #=========================================================
##    # frontend methods
##    #=========================================================
##    def get_hash(self, data):
##        "return hash of raw data"
##        return hashlib.new(iana_to_hashlib(self.alg), data).digest()
##
##    def get_client_proof(self, msg):
##        "return client proof of specified auth msg text"
##        return xor_bytes(self.client_key, self.get_client_sig(msg))
##
##    def get_encoded_client_proof(self, msg):
##        return self.get_client_proof(msg).encode("base64").rstrip()
##
##    def get_client_sig(self, msg):
##        "return client signature of specified auth msg text"
##        return self.get_hmac(self.stored_key, msg)
##
##    def get_server_sig(self, msg):
##        "return server signature of specified auth msg text"
##        return self.get_hmac(self.server_key, msg)
##
##    def get_encoded_server_sig(self, msg):
##        return self.get_server_sig(msg).encode("base64").rstrip()
##
##    def format_server_response(self, client_nonce, server_nonce):
##        return 'r={client_nonce}{server_nonce},s={salt},i={rounds}'.format(
##            client_nonce=client_nonce,
##            server_nonce=server_nonce,
##            rounds=self.rounds,
##            salt=self.encoded_salt,
##            )
##
##    def format_auth_msg(self, username, client_nonce, server_nonce,
##                        header='c=biws'):
##        return (
##            'n={username},r={client_nonce}'
##                ','
##            'r={client_nonce}{server_nonce},s={salt},i={rounds}'
##                ','
##            '{header},r={client_nonce}{server_nonce}'
##            ).format(
##                username=username,
##                client_nonce=client_nonce,
##                server_nonce=server_nonce,
##                salt=self.encoded_salt,
##                rounds=self.rounds,
##                header=header,
##                )
##
##    #=========================================================
##    # helpers to calculate & cache constant data
##    #=========================================================
##    def _calc_get_hmac(self):
##        return get_prf("hmac-" + iana_to_hashlib(self.alg))[0]
##
##    def _calc_client_key(self):
##        return self.get_hmac(self.digest, b("Client Key"))
##
##    def _calc_stored_key(self):
##        return self.get_hash(self.client_key)
##
##    def _calc_server_key(self):
##        return self.get_hmac(self.digest, b("Server Key"))
##
##    def _calc_encoded_salt(self):
##        return self.salt.encode("base64").rstrip()
##
##    #=========================================================
##    # hacks for calculated attributes
##    #=========================================================
##
##    def __getattr__(self, attr):
##        if not attr.startswith("_"):
##            f = getattr(self, "_calc_" + attr, None)
##            if f:
##                value = f()
##                setattr(self, attr, value)
##                return value
##        raise AttributeError("attribute not found")
##
##    def __dir__(self):
##        cdir = dir(self.__class__)
##        attrs = set(cdir)
##        attrs.update(self.__dict__)
##        attrs.update(attr[6:] for attr in cdir
##                     if attr.startswith("_calc_"))
##        return sorted(attrs)
##    #=========================================================
##    # eoc
##    #=========================================================

#=============================================================================
# eof
#=============================================================================
